import os
import torch
import platform
import numpy as np
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10, CIFAR100, MNIST, SVHN


from data.sampler import BalancedSampler as BS


# Dataset name -> torchvision dataset class used to construct it.
# 'cifar10H' is built with the plain CIFAR10 class.
cDT = dict(
    cifar10=CIFAR10,
    cifar10H=CIFAR10,
    cifar100=CIFAR100,
    svhn=SVHN,
    mnist=MNIST,
)
# Dataset names referenced by the lookup tables in this module.
DATASETS = [
    'cifar10',
    'cifar10H',
    'cifar100',
    'svhn',
    'stl10',
    'mnist',
    'imagenet',
]

# Root directory per dataset: a folder relative to the working directory on
# Windows, a folder under the user's home directory everywhere else.
_DATA_HOME = "./data/" if platform.system() == "Windows" else "~/data/"
data_path = {
    'cifar10': _DATA_HOME + "cifar10/",
    'cifar10H': _DATA_HOME + "cifar10/",  # shares the CIFAR-10 directory
    'cifar100': _DATA_HOME + "cifar100/",
    'svhn': _DATA_HOME + "svhn/",
    'mnist': _DATA_HOME + "mnist/",
    'stl10': _DATA_HOME + "stl10/",
    'imagenet': '~/data/imagenet/',  # same location on every platform
}
# Number of classes per dataset.
# NOTE(review): imagenet is set to -1, presumably a placeholder - confirm.
n_cls = dict(
    cifar10=10,
    cifar10H=10,
    cifar100=100,
    svhn=10,
    mnist=10,
    stl10=10,
    imagenet=-1,
)
# Per-dataset channel statistics, each entry stored as [MEAN, STD].
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2023, 0.1994, 0.2010)
_HALF_RGB = (0.5, 0.5, 0.5)

normalization_infos = {
    'cifar10': [_CIFAR10_MEAN, _CIFAR10_STD],
    'cifar100': [
        (0.5071, 0.4867, 0.4408),
        (0.2675, 0.2565, 0.2761),
    ],
    # cifar10H carries the same statistics as cifar10 (separate list object).
    'cifar10H': [_CIFAR10_MEAN, _CIFAR10_STD],
    'svhn': [
        (0.4376821, 0.4437697, 0.47280442),
        (0.19803012, 0.20101562, 0.19703614),
    ],
    'mnist': [(0.1307,), (0.3081,)],
    'celeba': [_HALF_RGB, _HALF_RGB],
}
# Input shape per dataset - presumably (channels, height, width).
_RGB_32 = (3, 32, 32)
sh = {
    'cifar10': _RGB_32,
    'cifar10H': _RGB_32,
    'cifar100': _RGB_32,
    'svhn': _RGB_32,
    'mnist': (1, 28, 28),
}

# Input value range per dataset; every dataset listed in `sh` uses (-1, 1).
input_range = dict.fromkeys(sh, (-1, 1))


# def calculate_mean_std(data):
    # m_a, s_a = normalization_infos['cifar10']
    # m_b, s_b = normalization_infos[data]

    # m_a, s_a = np.array(m_a), np.array(s_a)
    # m_b, s_b = np.array(m_b), np.array(s_b)

    # mean = m_b - (2 * m_a - 1) * s_b / (2 * s_a)
    # std = s_b / (2 * s_a)
    # return transforms.Normalize(mean=mean, std=std)


def normalize():
    """Return a ``transforms.Normalize`` with mean 0.5 and std 0.5 on three channels."""
    half = (.5, .5, .5)
    return transforms.Normalize(mean=half, std=half)


class GaussianNoise:
    """Callable transform that adds Gaussian noise scaled by ``sigma``.

    This is a module-level class rather than a lambda so that a transform
    pipeline containing it can be pickled; ``DataLoader`` worker processes
    need that when they are started with the ``spawn`` method (e.g. Windows).
    """

    def __init__(self, sigma=0.):
        self.sigma = sigma

    def __call__(self, x):
        """Return ``x + sigma * noise`` where ``noise`` is drawn via ``torch.randn_like(x)``."""
        return x + self.sigma * torch.randn_like(x)

    def __repr__(self):
        return f"{type(self).__name__}(sigma={self.sigma})"


def _normalizer_for(dt):
    """Return the normalisation step matching the channel count of dataset ``dt``."""
    if dt == 'mnist':
        # MNIST tensors have a single channel (see `sh`); the three-channel
        # statistics from normalize() cannot be broadcast onto them.
        return transforms.Normalize(mean=(.5,), std=(.5,))
    return normalize()


def get_transform(sigma=0.):
    """Return a pipeline of ``ToTensor``, ``normalize()`` and additive noise.

    Args:
        sigma: scale of the Gaussian noise added after normalisation. The
            noise step is always part of the pipeline, even for ``sigma=0``.
    """
    return transforms.Compose([
        transforms.ToTensor(),
        normalize(),
        GaussianNoise(sigma),
    ])


def transform(args, dt, type):
    """Build the torchvision transform pipeline for one dataset split.

    Args:
        args: run configuration. Reads ``crl``, ``cropsize``, ``arch``,
            ``x_noise`` and ``x_sigma`` (only those needed by the branch
            taken).
        dt: dataset name, e.g. ``'mnist'`` or ``'svhn'``.
        type: ``'tr'`` for the augmented training pipeline, ``'vl'`` for the
            plain evaluation pipeline.

    Returns:
        A ``transforms.Compose`` instance.

    Raises:
        ValueError: if ``type`` is neither ``'tr'`` nor ``'vl'``.

    Note:
        ``args.x_sigma`` is captured when the pipeline is built; later changes
        to ``args`` do not affect an existing pipeline.
    """
    if type not in ('tr', 'vl'):
        raise ValueError(f"unknown transform type {type!r}; expected 'tr' or 'vl'")

    if type == 'vl':
        steps = [transforms.Resize((224, 224))] if 'sc' in args.arch else []
        steps.append(transforms.ToTensor())
        steps.append(_normalizer_for(dt))
        return transforms.Compose(steps)

    # Training pipeline: pick the augmentation set, then convert and normalise.
    if args.crl:
        steps = [
            transforms.RandomResizedCrop(size=args.cropsize, scale=(0.2, 1.)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply(
                [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
        ]
    elif dt == 'mnist':
        steps = []
    elif dt == 'svhn':
        # No horizontal flip for SVHN, unlike the default branch below.
        steps = [
            transforms.Pad(4, padding_mode="reflect"),
            transforms.RandomCrop(32),
        ]
    elif 'sc' in args.arch:
        steps = [transforms.Resize((224, 224))]
    else:
        steps = [
            transforms.Pad(4, padding_mode="reflect"),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
        ]

    steps.append(transforms.ToTensor())
    steps.append(_normalizer_for(dt))
    if args.x_noise:
        steps.append(GaussianNoise(args.x_sigma))
    return transforms.Compose(steps)


def get_dl_tr(args, shuffle=True, no_aug=False, is_uc=False):
    """Return the training-split loader; all options are forwarded to ``_get_dl``."""
    return _get_dl(
        args,
        is_train=True,
        shuffle=shuffle,
        no_aug=no_aug,
        is_uc=is_uc,
    )


def get_dl_vl(args, shuffle=False, no_aug=False):
    """Return the non-training-split loader; options are forwarded to ``_get_dl``."""
    return _get_dl(args, is_train=False, shuffle=shuffle, no_aug=no_aug)

def get_dl_robust(args, no_aug=False):
    """Return an unshuffled loader over ``Robust_ds`` with batch size ``args.bsz``.

    ``no_aug`` is accepted but not used by this function.
    """
    # Imported here rather than at module level, as in the original layout.
    from data.ds import Robust_ds

    robust_set = Robust_ds()
    loader = DataLoader(robust_set, batch_size=args.bsz, shuffle=False)
    return loader


def _get_dl(args, is_train, shuffle, no_aug, is_uc=False):
    """Build the ``DataLoader`` for ``args.dataset``.

    Args:
        args: run configuration. Reads ``dataset``, ``bsz``, ``bsz_vl``,
            ``debug``, ``bs``, ``ebm`` and ``num_workers`` (plus whatever
            ``transform`` reads).
        is_train: select the training split (and ``args.bsz``) when true,
            otherwise the test split (and ``args.bsz_vl``).
        shuffle: passed to the loader; forced to ``False`` when the balanced
            sampler is enabled through ``args.bs``.
        no_aug: use the ``'vl'`` transform even for the training split.
        is_uc: halve the batch size (integer division).

    Returns:
        The configured ``DataLoader``.
    """
    name = args.dataset
    batch_size = args.bsz if is_train else args.bsz_vl
    if is_uc:
        batch_size //= 2
    mode = 'vl' if (no_aug or not is_train) else 'tr'

    dataset_cls = cDT[name]
    root = data_path[name]
    pipeline = transform(args, name, mode)
    if name == 'svhn':
        # SVHN selects its split by name instead of a `train` flag.
        dset = dataset_cls(
            root=root,
            download=True,
            transform=pipeline,
            split='train' if is_train else 'test',
        )
    else:
        dset = dataset_cls(
            root=root + 'train/',
            train=is_train,
            download=True,
            transform=pipeline,
        )

    if args.debug:
        # Debug runs keep only the first 64 samples.
        dset.data = dset.data[:64]

    sampler = None
    if args.bs:
        sampler, shuffle = BS(dset, args), False

    if not is_train and name in ('cifar10h', 'cifar10H'):
        # Replace the test labels with the argmax of the stored probabilities.
        probs_file = os.path.expanduser(os.path.join(root, 'cifar10h-probs.npy'))
        dset.targets = np.load(probs_file).argmax(-1)

    drop_incomplete = bool(args.ebm)
    loader = DataLoader(
        dset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=drop_incomplete,
    )
    print(f"Data loaded with {len(dset)} {mode} imgs.")
    return loader


def cycle(loader):
    """Yield items from ``loader`` repeatedly, restarting it after each pass.

    The generator ends as soon as a complete pass over ``loader`` yields
    nothing (an empty loader, or a one-shot iterator that is already
    exhausted), mirroring ``itertools.cycle`` on an empty iterable. Without
    that check the outer loop would spin forever without producing a value.

    Args:
        loader: any iterable; it is re-iterated at the start of every pass.

    Yields:
        The items of ``loader``, pass after pass.
    """
    while True:
        produced_any = False
        for data in loader:
            produced_any = True
            yield data
        if not produced_any:
            return
